--- ../../plenoxels/opt/util/dataset_base.py	2024-02-13 12:06:07.930166019 -0500
+++ ../google/plenoxels/opt/util/dataset_base.py	2024-02-13 10:46:24.452645156 -0500
@@ -14,8 +14,9 @@
     c2w: torch.Tensor  # C2W OpenCV poses
     gt: Union[torch.Tensor, List[torch.Tensor]]   # RGB images
     device : Union[str, torch.device]
-
-    def __init__(self):
+    opencv_mode: bool
+    
+    def __init__(self, opencv_mode=True, rays_at_center_pixel=True):
         self.ndc_coeffs = (-1, -1)
         self.use_sphere_bound = False
         self.should_use_background = True # a hint
@@ -23,6 +24,11 @@
         self.scene_center = [0.0, 0.0, 0.0]
         self.scene_radius = [1.0, 1.0, 1.0]
         self.permutation = False
+        # Optional mask of pixels to exclude from fitting (truthy = invalid); None keeps every pixel
+        self.invalid_pixels = None
+
+        self.opencv_mode = opencv_mode
+        self.rays_at_center_pixel = rays_at_center_pixel
 
     def shuffle_rays(self):
         """
@@ -42,13 +48,18 @@
         true_factor = self.h_full / self.h
         self.intrins = self.intrins_full.scale(1.0 / true_factor)
         yy, xx = torch.meshgrid(
-            torch.arange(self.h, dtype=torch.float32) + 0.5,
-            torch.arange(self.w, dtype=torch.float32) + 0.5,
+            torch.arange(self.h, dtype=torch.float32) + (0.5 if self.rays_at_center_pixel else 0.0),
+            torch.arange(self.w, dtype=torch.float32) + (0.5 if self.rays_at_center_pixel else 0.0),
         )
         xx = (xx - self.intrins.cx) / self.intrins.fx
         yy = (yy - self.intrins.cy) / self.intrins.fy
         zz = torch.ones_like(xx)
-        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
+        if self.opencv_mode:
+            dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention
+        else:
+            # Match PlenOctrees, which uses the OpenGL camera convention
+            # (y and z negated relative to OpenCV); see svox/renderer.py:L415
+            dirs = torch.stack((xx, -1*yy, -1*zz), dim=-1) # OpenGL convention
         dirs /= torch.norm(dirs, dim=-1, keepdim=True)
         dirs = dirs.reshape(1, -1, 3, 1)
         del xx, yy, zz
@@ -66,6 +77,12 @@
             origins = origins.view(-1, 3)
             dirs = dirs.view(-1, 3)
             gt = gt.reshape(-1, 3)
+            if self.invalid_pixels is not None:
+                # Drop the rays (origins, dirs, gt) of pixels flagged as invalid
+                valid = torch.logical_not(self.invalid_pixels.reshape(-1))
+                origins = origins[valid]
+                dirs = dirs[valid]
+                gt = gt[valid]
 
         self.rays_init = Rays(origins=origins, dirs=dirs, gt=gt)
         self.rays = self.rays_init
